/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <string>

#include "../../../common/data_utils.h"
#ifndef ASCENDC_CPU_DEBUG
#include "acl/acl.h"
// NPU build: host-side launcher for the kernel; main() calls it with a device stream and device buffers.
// Defined outside this file (presumably the generated kernel launch stub — confirm against the build).
extern void normalize_custom_do(uint32_t blockDim, void *l2ctrl, void *stream, uint8_t *srcGm, uint8_t *inMeanGm,
    uint8_t *inVarGm, uint8_t *inGammaGm, uint8_t *inBetaGm, uint8_t *outGm, uint8_t *outRstdGm, uint8_t *workspace,
    uint8_t *tiling);
#else
#include "tikicpulib.h"
// CPU-simulation build: the kernel entry point itself, which main() runs through ICPU_RUN_KF.
extern "C" __global__ __aicore__ void normalize_custom(GM_ADDR srcGm, GM_ADDR inMeanGm, GM_ADDR inVarGm,
    GM_ADDR inGammaGm, GM_ADDR inBetaGm, GM_ADDR outGm, GM_ADDR outRstdGm, GM_ADDR workspace, GM_ADDR tiling);
#endif
// ---- Launch / buffer configuration -------------------------------------------------------
constexpr uint8_t BLOCK_DIM{1};                  // blockDim handed to the kernel launch
constexpr uint32_t TILINGDATA_SIZE{6};           // tiling payload length, counted in uint32_t words
constexpr uint32_t WORKSPACE_SIZE{1024 * 1024};  // workspace buffer size in bytes (1 MiB)

// ---- Problem shape -----------------------------------------------------------------------
// The source/output tensors are A_SIZE x R_SIZE_WITH_PAD floats; R_SIZE is the un-padded
// length of the R dimension that is reported separately to the tiling generator.
constexpr uint32_t A_SIZE{8};
constexpr uint32_t R_SIZE{64};
constexpr uint32_t R_SIZE_WITH_PAD{64};

// Builds the tiling blob for the given shape. Defined in another translation unit; main()
// releases the returned buffer with free(), so it is presumably malloc-allocated — confirm.
extern uint8_t *GenerateTiling(uint32_t aLength, uint32_t rLength, uint32_t rLengthWithPadding);

/**
 * Decide whether one output element disagrees with its golden value.
 *
 * Two NaNs are treated as agreeing, and a NaN against a non-NaN is a mismatch. Values that
 * compare equal always agree, which also covers identical infinities. Otherwise the element
 * matches when either the absolute or the relative error is within eps.
 *
 * The tolerance test is written as !(abs <= eps || rel <= eps) rather than
 * (abs > eps && rel > eps). A NaN error term, e.g. a finite output against an infinite golden,
 * is therefore reported as a mismatch instead of slipping through; for ordinary finite values
 * the two forms agree.
 */
static bool IsElementMismatch(float actual, float golden, float eps)
{
    const bool actualNan = std::isnan(actual);
    const bool goldenNan = std::isnan(golden);
    if (actualNan || goldenNan) {
        return actualNan != goldenNan;
    }
    if (actual == golden) {
        return false;
    }
    // std::fabs: an unqualified abs() may resolve to int abs(int) and truncate the value.
    const float absErr = std::fabs(actual - golden);
    const float relErr = absErr / std::fabs(golden);
    return !(absErr <= eps || relErr <= eps);
}

/** Release a golden-data buffer allocated by CompareResult (GmAlloc on CPU, aclrtMallocHost on NPU). */
static void FreeGoldenBuffer(void *buffer)
{
#ifdef ASCENDC_CPU_DEBUG
    AscendC::GmFree(buffer);
#else
    CHECK_ACL(aclrtFreeHost(buffer));
#endif
}

/**
 * Compare a float output buffer with the golden file ../output/golden_<goldenName>.bin.
 *
 * @param outputData  Host-readable buffer holding the produced output, interpreted as float.
 * @param outSize     Size of outputData in bytes; the golden file must supply the same number of bytes.
 * @param goldenName  Suffix used to build the golden file name and the log messages.
 * @return true when the golden file was read completely and every element agrees within
 *         an absolute OR relative tolerance of 1e-4; false otherwise. Every mismatching
 *         element is printed.
 */
static bool CompareResult(const void *outputData, int64_t outSize, const std::string &goldenName)
{
    void *goldenData = nullptr;
#ifdef ASCENDC_CPU_DEBUG
    goldenData = (uint8_t *)AscendC::GmAlloc(outSize);
#else
    CHECK_ACL(aclrtMallocHost((void **)(&goldenData), outSize));
#endif
    size_t goldenSize = outSize;
    const bool readOk = ReadFile("../output/golden_" + goldenName + ".bin", goldenSize, goldenData, goldenSize);
    if (!readOk) {
        printf("ReadFile golden_%s.bin failed!\n", goldenName.c_str());
        FreeGoldenBuffer(goldenData);  // free on the error path too, not only on success
        return false;
    }
    printf("ReadFile golden_%s.bin success!\n", goldenName.c_str());

    // NOTE(review): assumes ReadFile reports the number of bytes it read through its size
    // argument. A golden file shorter than the output would leave the tail of goldenData
    // uninitialised, so treat a size difference as a failure. This is harmless if ReadFile
    // leaves the value untouched.
    if (static_cast<int64_t>(goldenSize) != outSize) {
        printf("ReadFile golden_%s.bin size mismatch: expected %lld bytes, got %zu\n", goldenName.c_str(),
            static_cast<long long>(outSize), goldenSize);
        FreeGoldenBuffer(goldenData);
        return false;
    }

    constexpr float EPS = 1e-4f;
    const float *outValues = static_cast<const float *>(outputData);
    const float *goldenValues = static_cast<const float *>(goldenData);
    const size_t elementCount = static_cast<size_t>(outSize) / sizeof(float);
    int64_t wrongNum = 0;
    for (size_t i = 0; i < elementCount; ++i) {
        if (IsElementMismatch(outValues[i], goldenValues[i], EPS)) {
            printf("CompareResult golden_%s.bin failed output is %lf, golden is %lf\n", goldenName.c_str(),
                outValues[i], goldenValues[i]);
            ++wrongNum;
        }
    }
    FreeGoldenBuffer(goldenData);

    if (wrongNum != 0) {
        return false;
    }
    printf("CompareResult golden_%s.bin success!\n", goldenName.c_str());
    return true;
}

/**
 * Host-side test driver for the normalize_custom kernel.
 *
 * The driver:
 *  - generates the tiling data;
 *  - loads the inputs from ../input/;
 *  - runs the kernel, either in CPU simulation through ICPU_RUN_KF when ASCENDC_CPU_DEBUG is
 *    defined, or on device 0 through ACL;
 *  - writes both outputs to ../output/ and compares them with the golden files.
 *
 * @return 0 when the tiling was generated, every input file was read and both outputs match
 *         their golden data; 1 otherwise, so calling scripts can detect a failing run.
 */
int32_t main(int32_t argc, char *argv[])
{
    uint32_t blockDim = BLOCK_DIM;
    // Buffer sizes in bytes. src/out hold A_SIZE x R_SIZE_WITH_PAD floats, mean/var/rstd hold
    // one float per A row, and gamma/beta hold one float per padded R column.
    size_t inputSrcSize = A_SIZE * R_SIZE_WITH_PAD * sizeof(float);
    size_t inputMeanSize = A_SIZE * sizeof(float);
    size_t inputVarSize = A_SIZE * sizeof(float);
    size_t inputGammaSize = R_SIZE_WITH_PAD * sizeof(float);
    size_t inputBetaSize = R_SIZE_WITH_PAD * sizeof(float);
    size_t outputSize = A_SIZE * R_SIZE_WITH_PAD * sizeof(float);
    size_t outputRstdSize = A_SIZE * sizeof(float);

    size_t workspaceSize = WORKSPACE_SIZE;
    size_t tilingFileSize = TILINGDATA_SIZE * sizeof(uint32_t);
    uint8_t *tilingBuf = GenerateTiling(A_SIZE, R_SIZE, R_SIZE_WITH_PAD);
    if (tilingBuf == nullptr) {
        // Nothing is allocated yet; stop before the tiling data is copied from a null pointer.
        printf("GenerateTiling failed!\n");
        printf("test failed!\n");
        return 1;
    }

    // Overall verdict of the run; it also becomes the process exit status at the end.
    bool goldenResult = true;

#ifdef ASCENDC_CPU_DEBUG
    uint8_t *inputSrc = (uint8_t *)AscendC::GmAlloc(inputSrcSize);
    uint8_t *inputMean = (uint8_t *)AscendC::GmAlloc(inputMeanSize);
    uint8_t *inputVar = (uint8_t *)AscendC::GmAlloc(inputVarSize);
    uint8_t *inputGamma = (uint8_t *)AscendC::GmAlloc(inputGammaSize);
    uint8_t *inputBeta = (uint8_t *)AscendC::GmAlloc(inputBetaSize);
    uint8_t *output = (uint8_t *)AscendC::GmAlloc(outputSize);
    uint8_t *outputRstd = (uint8_t *)AscendC::GmAlloc(outputRstdSize);

    uint8_t *workspace = (uint8_t *)AscendC::GmAlloc(workspaceSize);
    uint8_t *tiling = (uint8_t *)AscendC::GmAlloc(tilingFileSize);

    // A missing or unreadable input must fail the run instead of silently testing garbage.
    bool inputsOk = true;
    inputsOk &= ReadFile("../input/input_srcGm.bin", inputSrcSize, inputSrc, inputSrcSize);
    inputsOk &= ReadFile("../input/input_inMeanGm.bin", inputMeanSize, inputMean, inputMeanSize);
    inputsOk &= ReadFile("../input/input_inVarGm.bin", inputVarSize, inputVar, inputVarSize);
    inputsOk &= ReadFile("../input/input_inGammaGm.bin", inputGammaSize, inputGamma, inputGammaSize);
    inputsOk &= ReadFile("../input/input_inBetaGm.bin", inputBetaSize, inputBeta, inputBetaSize);
    if (!inputsOk) {
        printf("ReadFile input failed!\n");
    }

    memcpy_s(tiling, tilingFileSize, tilingBuf, tilingFileSize);

    AscendC::SetKernelMode(KernelMode::AIV_MODE);
    ICPU_RUN_KF(normalize_custom, blockDim, inputSrc, inputMean, inputVar, inputGamma, inputBeta, output, outputRstd,
        workspace, tiling);

    WriteFile("../output/output_outGm.bin", output, outputSize);
    WriteFile("../output/output_outRstdGm.bin", outputRstd, outputRstdSize);

    goldenResult &= inputsOk;
    goldenResult &= CompareResult(output, outputSize, "outGm");
    goldenResult &= CompareResult(outputRstd, outputRstdSize, "outRstdGm");
    if (goldenResult) {
        printf("test pass!\n");
    } else {
        printf("test failed!\n");
    }

    AscendC::GmFree((void *)inputSrc);
    AscendC::GmFree((void *)inputMean);
    AscendC::GmFree((void *)inputVar);
    AscendC::GmFree((void *)inputGamma);
    AscendC::GmFree((void *)inputBeta);
    AscendC::GmFree((void *)output);
    AscendC::GmFree((void *)outputRstd);
    AscendC::GmFree((void *)workspace);
    AscendC::GmFree((void *)tiling);

#else
    CHECK_ACL(aclInit(nullptr));
    aclrtContext context;
    int32_t deviceId = 0;
    CHECK_ACL(aclrtSetDevice(deviceId));
    CHECK_ACL(aclrtCreateContext(&context, deviceId));
    aclrtStream stream = nullptr;
    CHECK_ACL(aclrtCreateStream(&stream));

    uint8_t *srcHost, *inMeanHost, *inVarHost, *inGammaHost, *inBetaHost, *outHost, *outRstdHost, *workspaceHost;
    uint8_t *srcDevice, *inMeanDevice, *inVarDevice, *inGammaDevice, *inBetaDevice, *outDevice, *outRstdDevice,
        *workspaceDevice, *tilingDevice;

    CHECK_ACL(aclrtMallocHost((void **)(&srcHost), inputSrcSize));
    CHECK_ACL(aclrtMallocHost((void **)(&inMeanHost), inputMeanSize));
    CHECK_ACL(aclrtMallocHost((void **)(&inVarHost), inputVarSize));
    CHECK_ACL(aclrtMallocHost((void **)(&inGammaHost), inputGammaSize));
    CHECK_ACL(aclrtMallocHost((void **)(&inBetaHost), inputBetaSize));
    CHECK_ACL(aclrtMallocHost((void **)(&outHost), outputSize));
    CHECK_ACL(aclrtMallocHost((void **)(&outRstdHost), outputRstdSize));
    CHECK_ACL(aclrtMallocHost((void **)(&workspaceHost), workspaceSize));

    CHECK_ACL(aclrtMalloc((void **)&srcDevice, inputSrcSize, ACL_MEM_MALLOC_HUGE_FIRST));
    CHECK_ACL(aclrtMalloc((void **)&inMeanDevice, inputMeanSize, ACL_MEM_MALLOC_HUGE_FIRST));
    CHECK_ACL(aclrtMalloc((void **)&inVarDevice, inputVarSize, ACL_MEM_MALLOC_HUGE_FIRST));
    CHECK_ACL(aclrtMalloc((void **)&inGammaDevice, inputGammaSize, ACL_MEM_MALLOC_HUGE_FIRST));
    CHECK_ACL(aclrtMalloc((void **)&inBetaDevice, inputBetaSize, ACL_MEM_MALLOC_HUGE_FIRST));
    CHECK_ACL(aclrtMalloc((void **)&outDevice, outputSize, ACL_MEM_MALLOC_HUGE_FIRST));
    CHECK_ACL(aclrtMalloc((void **)&outRstdDevice, outputRstdSize, ACL_MEM_MALLOC_HUGE_FIRST));
    CHECK_ACL(aclrtMalloc((void **)&workspaceDevice, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST));
    CHECK_ACL(aclrtMalloc((void **)&tilingDevice, tilingFileSize, ACL_MEM_MALLOC_HUGE_FIRST));

    // A missing or unreadable input must fail the run instead of silently testing garbage.
    bool inputsOk = true;
    inputsOk &= ReadFile("../input/input_srcGm.bin", inputSrcSize, srcHost, inputSrcSize);
    inputsOk &= ReadFile("../input/input_inMeanGm.bin", inputMeanSize, inMeanHost, inputMeanSize);
    inputsOk &= ReadFile("../input/input_inVarGm.bin", inputVarSize, inVarHost, inputVarSize);
    inputsOk &= ReadFile("../input/input_inGammaGm.bin", inputGammaSize, inGammaHost, inputGammaSize);
    inputsOk &= ReadFile("../input/input_inBetaGm.bin", inputBetaSize, inBetaHost, inputBetaSize);
    if (!inputsOk) {
        printf("ReadFile input failed!\n");
    }

    CHECK_ACL(aclrtMemcpy(workspaceDevice, workspaceSize, workspaceHost, workspaceSize, ACL_MEMCPY_HOST_TO_DEVICE));
    CHECK_ACL(aclrtMemcpy(tilingDevice, tilingFileSize, tilingBuf, tilingFileSize, ACL_MEMCPY_HOST_TO_DEVICE));

    CHECK_ACL(aclrtMemcpy(srcDevice, inputSrcSize, srcHost, inputSrcSize, ACL_MEMCPY_HOST_TO_DEVICE));
    CHECK_ACL(aclrtMemcpy(inMeanDevice, inputMeanSize, inMeanHost, inputMeanSize, ACL_MEMCPY_HOST_TO_DEVICE));
    CHECK_ACL(aclrtMemcpy(inVarDevice, inputVarSize, inVarHost, inputVarSize, ACL_MEMCPY_HOST_TO_DEVICE));
    CHECK_ACL(aclrtMemcpy(inGammaDevice, inputGammaSize, inGammaHost, inputGammaSize, ACL_MEMCPY_HOST_TO_DEVICE));
    CHECK_ACL(aclrtMemcpy(inBetaDevice, inputBetaSize, inBetaHost, inputBetaSize, ACL_MEMCPY_HOST_TO_DEVICE));

    normalize_custom_do(blockDim, nullptr, stream, srcDevice, inMeanDevice, inVarDevice, inGammaDevice, inBetaDevice,
        outDevice, outRstdDevice, workspaceDevice, tilingDevice);

    CHECK_ACL(aclrtSynchronizeStream(stream));

    CHECK_ACL(aclrtMemcpy(outHost, outputSize, outDevice, outputSize, ACL_MEMCPY_DEVICE_TO_HOST));
    CHECK_ACL(aclrtMemcpy(outRstdHost, outputRstdSize, outRstdDevice, outputRstdSize, ACL_MEMCPY_DEVICE_TO_HOST));

    WriteFile("../output/output_outGm.bin", outHost, outputSize);
    WriteFile("../output/output_outRstdGm.bin", outRstdHost, outputRstdSize);

    goldenResult &= inputsOk;
    goldenResult &= CompareResult(outHost, outputSize, "outGm");
    goldenResult &= CompareResult(outRstdHost, outputRstdSize, "outRstdGm");
    if (goldenResult) {
        printf("test pass!\n");
    } else {
        printf("test failed!\n");
    }

    CHECK_ACL(aclrtFree(srcDevice));
    CHECK_ACL(aclrtFree(inMeanDevice));
    CHECK_ACL(aclrtFree(inVarDevice));
    CHECK_ACL(aclrtFree(inGammaDevice));
    CHECK_ACL(aclrtFree(inBetaDevice));
    CHECK_ACL(aclrtFree(outDevice));
    CHECK_ACL(aclrtFree(outRstdDevice));
    CHECK_ACL(aclrtFree(workspaceDevice));
    CHECK_ACL(aclrtFree(tilingDevice));

    CHECK_ACL(aclrtFreeHost(srcHost));
    CHECK_ACL(aclrtFreeHost(inMeanHost));
    CHECK_ACL(aclrtFreeHost(inVarHost));
    CHECK_ACL(aclrtFreeHost(inGammaHost));
    CHECK_ACL(aclrtFreeHost(inBetaHost));
    CHECK_ACL(aclrtFreeHost(outHost));
    CHECK_ACL(aclrtFreeHost(outRstdHost));
    CHECK_ACL(aclrtFreeHost(workspaceHost));

    CHECK_ACL(aclrtDestroyStream(stream));
    CHECK_ACL(aclrtDestroyContext(context));
    CHECK_ACL(aclrtResetDevice(deviceId));
    CHECK_ACL(aclFinalize());
#endif
    free(tilingBuf);
    // Report the verdict through the exit status as well as stdout.
    return goldenResult ? 0 : 1;
}
